using OrderedCollections: OrderedSet

"""
    Bijection(pairs)

An injective mapping stored as a forward `Dict` (`map`) together with its inverse
(`invmap`).

`pairs` must support `keytype`, `valtype` and iteration over `key => value` pairs
(e.g. an `AbstractDict`); its key/value types become the parameters `T`/`S`.
Pairs with `k == v` are validated but not stored, so `map` and `invmap` hold only
the points that actually move.

Throws an `ErrorException` if two keys share a value — including a key that maps
onto another key's fixed point, e.g. `Dict(1 => 2, 2 => 2)`.
"""
struct Bijection{T, S}
    map::Dict{T, S}
    invmap::Dict{S, T}

    function Bijection(_map)
        _T = keytype(_map)
        _S = valtype(_map)
        # Check injectivity over every input pair *before* dropping fixed points;
        # otherwise `1 => 2, 2 => 2` would lose `2 => 2` and slip through as `1 => 2`.
        _seen = Dict{_S, _T}()
        for (k, v) ∈ _map
            if haskey(_seen, v)
                error("Duplicate value $v for keys $(_seen[v]) and $k")
            end
            _seen[v] = k
        end
        # Store only the moved points, plus the matching inverse.
        _fwd = Dict{_T, _S}(k => v for (k, v) ∈ _map if k != v)
        _inv = Dict{_S, _T}(v => k for (k, v) ∈ _fwd)
        return new{_T, _S}(_fwd, _inv)
    end
end

# Domain (`keytype`) and codomain (`valtype`) types of a `Bijection`. As Base does for
# `AbstractDict`, these are answered from the type as well as from instances, so
# generic code can call `keytype(typeof(b))`.
Base.keytype(::Type{Bijection{T, S}}) where {T, S} = T
Base.valtype(::Type{Bijection{T, S}}) where {T, S} = S
Base.keytype(b::Bijection) = keytype(typeof(b))
Base.valtype(b::Bijection) = valtype(typeof(b))

# Number of stored (moved) points; fixed points are not counted.
function Base.length(b::Bijection)
    fwd = b.map
    return length(fwd)
end

# Iteration delegates to the forward `Dict`, yielding its `key => value` pairs; the
# state passed between calls is the `Dict`'s own iteration state.
function Base.iterate(b::Bijection)
    return iterate(b.map)
end
function Base.iterate(b::Bijection, state)
    return iterate(b.map, state)
end

"""
    inv(b::Bijection)

Return the inverse bijection, which maps each stored value of `b` back to its key.

Extends `Base.inv` (like the other methods in this file extend `Base` functions)
instead of defining a new `inv` that would shadow `Base.inv` for all other types.
"""
function Base.inv(b::Bijection)
    return Bijection(b.invmap)
end

"""
    show(io::IO, b::Bijection)

Print `b` as `Bijection{T, S}(<forward Dict>)`.

The pieces are written straight to `io` rather than interpolated into an
intermediate string, so `IOContext` settings such as `:limit` and `:compact` are
honored by the inner `Dict`'s `show`.
"""
function Base.show(io::IO, b::Bijection{T, S}) where {T, S}
    print(io, "Bijection{", T, ", ", S, "}(")
    show(io, b.map)
    print(io, ")")
    return nothing
end

# Image of `k`: a key absent from the forward map is treated as a fixed point and
# returned unchanged.
Base.getindex(b::Bijection, k) = Base.get(b.map, k, k)

"""
    x * y

Compose bijections, applying `y` first and then `x`. For bijections whose moved keys
and moved values coincide (permutations of their support), `(x * y)[r] == x[y[r]]`.
"""
function Base.:*(x::Bijection{T, S}, y::Bijection{R, T}) where {R, T, S}
    composed = Dict{R, S}()
    # Points moved by `x`: pull each one back through `y`; a point that `y` does
    # not produce as a value pulls back to itself.
    for (mid, out) ∈ x.map
        composed[get(y.invmap, mid, mid)] = out
    end
    # Points moved by `y` whose image `x` leaves alone.
    for (src, mid) ∈ y.map
        haskey(x.map, mid) || (composed[src] = mid)
    end
    # The constructor discards any `r => r` pairs produced above.
    return Bijection(composed)
end

"""
    x::Bijection * y::Dict

Return a new `Dict` with the keys of `y` relabelled through `x`. A key moved by `x`
is stored under its image; an unmoved key is stored under itself. If an unmoved key
collides with the image of a moved key, the moved entry wins regardless of
iteration order.
"""
function Base.:*(x::Bijection{T, T}, y::Dict{T, S}) where {T, S}
    relabelled = Dict{T, S}()
    for (key, val) ∈ y
        if haskey(x.map, key)
            relabelled[x.map[key]] = val
        else
            # Only fill the slot if a moved key has not already claimed it.
            get!(relabelled, key, val)
        end
    end
    return relabelled
end

"""
    x::Dict * y::Bijection

Return a new `Dict` with the keys of `x` relabelled through `y`. This is the same
operation as `y * x`: moved keys go to their image, and unmoved keys stay put unless
a moved entry already claimed their slot.
"""
function Base.:*(x::Dict{T, S}, y::Bijection{T, T}) where {T, S}
    # Delegate to the `Bijection * Dict` method so both argument orders share a
    # single implementation (the relabelling must look keys up in `y`, not `x`).
    return y * x
end

# Map each element of a vector through the bijection, returning a new `Vector{S}`
# (elements are converted to `S`). The typed comprehension builds the result in one
# allocation, and any `AbstractVector{T}` (ranges, views, ...) is accepted.
Base.:*(x::Bijection{T, S}, y::AbstractVector{T}) where {T, S} = S[x[k] for k ∈ y]
Base.:*(x::AbstractVector{T}, y::Bijection{T, S}) where {T, S} = S[y[k] for k ∈ x]

# Image of a `Set` under the bijection, collected into a new `Set{S}`.
function Base.:*(x::Bijection{T, S}, y::Set{T}) where {T, S}
    return Set{S}(x[elem] for elem in y)
end
function Base.:*(x::Set{T}, y::Bijection{T, S}) where {T, S}
    return Set{S}(y[elem] for elem in x)
end

# Map each element of a non-empty homogeneous tuple through the bijection, keeping
# the tuple's length and converting elements to `S`. `Tuple{T, Vararg{T, N}}`
# (rather than `NTuple{N, T}`) excludes the empty tuple, so these methods can never
# overlap the scalar `y::T` methods when `T === Tuple{}`.
function Base.:*(x::Bijection{T, S}, y::Tuple{T, Vararg{T, N}}) where {T, S, N}
    return NTuple{N + 1, S}(x[k] for k ∈ y)
end
function Base.:*(x::Tuple{T, Vararg{T, N}}, y::Bijection{T, S}) where {T, S, N}
    return NTuple{N + 1, S}(y[k] for k ∈ x)
end

# Image of an `OrderedSet` under the bijection; elements are inserted into the new
# `OrderedSet{S}` in the iteration order of the input.
function Base.:*(x::Bijection{T, S}, y::OrderedSet) where {T, S}
    return OrderedSet{S}(x[elem] for elem in y)
end
function Base.:*(x::OrderedSet, y::Bijection{T, S}) where {T, S}
    return OrderedSet{S}(y[elem] for elem in x)
end

# Applying a bijection to a single point (from either side) is plain indexing.
function Base.:*(x::Bijection{T, S}, y::T) where {T, S}
    return getindex(x, y)
end
function Base.:*(x::T, y::Bijection{T, S}) where {T, S}
    return getindex(y, x)
end

# `nothing` is absorbing: multiplying it by a bijection on either side yields `nothing`.
Base.:*(::Bijection, ::Nothing) = nothing
Base.:*(::Nothing, ::Bijection) = nothing

# True when no point is moved, i.e. the forward map stores no pairs.
Base.isempty(x::Bijection) = length(x.map) == 0
